{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "725baefb-a72f-4436-86b7-dbcba36091e3",
   "metadata": {},
   "source": [
    "# Finetuning Bacformer for phenotypic traits prediction tutorial\n",
    "\n",
    "This tutorial outlines how to finetune a pretrained Bacformer model to predict phenotypic labels.\n",
    "\n",
    "We provide a dataset containing protein sequences for over `1,000` genomes across different species, each with a binary label. We show how to train and evaluate\n",
    "a finetuned Bacformer model for phenotype prediction. The framework presented here is, in principle, extendable to any bacterial phenotype.\n",
    "\n",
    "We recommend first checking out `phenotypic_traits_prediction_tutorial.ipynb`, which is significantly less computationally expensive and outlines how to train\n",
    "a simple linear regression model on precomputed Bacformer embeddings for phenotype prediction. If your phenotype is challenging or you want to provide\n",
    "your own phenotype labels, please use this tutorial.\n",
    "\n",
    "Before you start, make sure you have `bacformer` installed (see README.md for details) and run the notebook on a machine with a GPU.\n",
    "\n",
    "## Step 1: Import required dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d58a114-5c01-441c-9a9a-c49ab9d82c35",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from functools import partial\n",
    "\n",
    "import torch\n",
    "from bacformer.modeling import (\n",
    "    SPECIAL_TOKENS_DICT,\n",
    "    BacformerTrainer,\n",
    "    collate_genome_samples,\n",
    "    compute_metrics_binary_genome_pred,\n",
    ")\n",
    "from bacformer.pp import dataset_col_to_bacformer_inputs\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoConfig, AutoModelForSequenceClassification, EarlyStoppingCallback, TrainingArguments"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a07d306-a7c2-4f5b-b49c-50a5ded6d450",
   "metadata": {},
   "source": [
    "## Step 2: Load the dataset\n",
    "\n",
    "We use a dataset for predicting the `Catalase` phenotype, which denotes whether a bacterium produces the catalase enzyme that breaks down hydrogen peroxide (H₂O₂) into water and oxygen, thereby protecting the cell from oxidative stress. The phenotypic labels were collected from [1], and the task is a binary classification problem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcbfa844-5292-45bf-949b-b0ddbc305a4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\"macwiatrak/phenotypic-trait-catalase-protein-sequences\", keep_in_memory=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56be1905-1a3f-4d6f-a829-2011e66980d5",
   "metadata": {},
   "source": [
    "## Step 3: Embed the dataset with the base protein language model (pLM)\n",
    "\n",
    "The first step in using Bacformer is to embed the protein sequences with the base pLM, [ESM-2 t12 35M](https://huggingface.co/facebook/esm2_t12_35M_UR50D).\n",
    "\n",
    "This step should take ~1.5h on a single NVIDIA A100 GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a2a8432-1394-4a20-bc86-d7ddcac1bbb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# embed the protein sequences with the ESM-2 base model\n",
    "for split_name in dataset.keys():\n",
    "    # keep only the first 30 genomes of each split for a quick run;\n",
    "    # remove this line to embed the full dataset\n",
    "    dataset[split_name] = dataset[split_name].select(range(30))\n",
    "    dataset[split_name] = dataset_col_to_bacformer_inputs(\n",
    "        dataset=dataset[split_name],\n",
    "        max_n_proteins=7000,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94ae2944-40c4-48dd-8087-c91703d3d58f",
   "metadata": {},
   "source": [
    "## Step 4: Load the Bacformer model\n",
    "\n",
    "Load the pretrained Bacformer model with a classification layer on top, which we then finetune."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6aa09e6-3a6f-4140-af41-700a564314f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the Bacformer model for genome classification\n",
    "# for this task we use the Bacformer model trained on masked complete genomes\n",
    "config = AutoConfig.from_pretrained(\"macwiatrak/bacformer-masked-complete-genomes\", trust_remote_code=True)\n",
    "config.num_labels = 1\n",
    "config.problem_type = \"binary_classification\"\n",
    "bacformer_model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    \"macwiatrak/bacformer-masked-complete-genomes\", config=config, trust_remote_code=True\n",
    ").to(torch.bfloat16)\n",
    "print(\"Nr of parameters:\", sum(p.numel() for p in bacformer_model.parameters()))\n",
    "print(\"Nr of trainable parameters:\", sum(p.numel() for p in bacformer_model.parameters() if p.requires_grad))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdb205e0-5a29-4a5b-82e5-52e8299ff511",
   "metadata": {},
   "source": [
    "## Step 5: Set up the trainer for finetuning\n",
    "\n",
    "Set up a trainer object for finetuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f14ac1d5-b60a-49f0-8b75-715145dabcb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a trainer\n",
    "# get training args\n",
    "output_dir = \"output/pheno_trait_pred\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    save_total_limit=1,\n",
    "    learning_rate=0.00015,\n",
    "    num_train_epochs=5,\n",
    "    per_device_train_batch_size=1,\n",
    "    per_device_eval_batch_size=1,\n",
    "    gradient_accumulation_steps=8,\n",
    "    seed=1,\n",
    "    dataloader_num_workers=4,\n",
    "    bf16=True,\n",
    "    metric_for_best_model=\"eval_auroc\",\n",
    "    load_best_model_at_end=True,\n",
    "    greater_is_better=True,\n",
    ")\n",
    "\n",
    "# define a collate function for the dataset\n",
    "collate_genome_samples_fn = partial(collate_genome_samples, SPECIAL_TOKENS_DICT[\"PAD\"], 7000, 1000)\n",
    "trainer = BacformerTrainer(\n",
    "    model=bacformer_model,\n",
    "    data_collator=collate_genome_samples_fn,\n",
    "    train_dataset=dataset[\"train\"],\n",
    "    eval_dataset=dataset[\"validation\"],\n",
    "    args=training_args,\n",
    "    compute_metrics=compute_metrics_binary_genome_pred,\n",
    "    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],\n",
    ")"
   ]
  },
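  {
   "cell_type": "markdown",
   "id": "c7e1a2b4-3d5f-4a68-9b0c-1e2f3a4b5c6d",
   "metadata": {},
   "source": [
    "The training arguments above combine `per_device_train_batch_size=1` with `gradient_accumulation_steps=8`, so gradients from several genomes are accumulated before each optimizer update. The arithmetic can be sketched as follows. This is only an illustration: `n_devices` and `n_train_genomes` are assumed values (a single GPU and the 30-genome train subset selected in Step 3), and the exact number of updates the trainer performs per epoch may differ slightly depending on how it handles an incomplete final accumulation window.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# values taken from the TrainingArguments above\n",
    "per_device_train_batch_size = 1\n",
    "gradient_accumulation_steps = 8\n",
    "\n",
    "# assumptions for illustration only\n",
    "n_devices = 1  # a single GPU\n",
    "n_train_genomes = 30  # size of the train subset selected in Step 3\n",
    "\n",
    "# number of genomes contributing to each optimizer update\n",
    "effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * n_devices\n",
    "\n",
    "# approximate number of optimizer updates per epoch\n",
    "updates_per_epoch = math.ceil(n_train_genomes / effective_batch_size)\n",
    "\n",
    "print('effective batch size:', effective_batch_size)\n",
    "print('approx. updates per epoch:', updates_per_epoch)\n",
    "```\n",
    "\n",
    "Keeping the per-device batch size at 1 limits peak GPU memory per forward pass, while accumulation gives each update the gradient signal of a larger batch."
   ]
  },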
  {
   "cell_type": "markdown",
   "id": "3af49450-c1a8-4549-9288-8e55728facab",
   "metadata": {},
   "source": [
    "## Step 6: Finetune the model 🎉🚂😎\n",
    "\n",
    "Finetune the model to predict the catalase phenotype. Training should take ~15 min on a single NVIDIA A100 GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07af5dd0-c399-4d79-8f6e-8af746304b2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train the model\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7ded8af-f3b7-444b-84c5-21916da72622",
   "metadata": {},
   "source": [
    "## Step 7: Evaluate on the test set and run predictions\n",
    "\n",
    "With the trained model, you can now evaluate on the test set and obtain predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69d68b84-b1d3-485d-b7bf-c1024f0029f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate the model on the test set\n",
    "test_output = trainer.predict(dataset[\"test\"])\n",
    "print(\"Test output:\", test_output.metrics)"
   ]
  },
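  {
   "cell_type": "markdown",
   "id": "d8f2b3c5-4e6a-4b79-8c1d-2f3a4b5c6d7e",
   "metadata": {},
   "source": [
    "The training arguments select the best checkpoint by `eval_auroc`. As a reminder of what this metric measures, AUROC equals the probability that a randomly chosen positive genome receives a higher score than a randomly chosen negative one. A minimal pure-Python sketch of that definition is given below. The `auroc` helper and the toy labels and scores are hypothetical and not part of `bacformer`; `compute_metrics_binary_genome_pred` may compute the metric differently.\n",
    "\n",
    "```python\n",
    "def auroc(labels, scores):\n",
    "    '''Fraction of (positive, negative) pairs ranked correctly; ties count as half.'''\n",
    "    pos = [s for y, s in zip(labels, scores) if y == 1]\n",
    "    neg = [s for y, s in zip(labels, scores) if y == 0]\n",
    "    if not pos or not neg:\n",
    "        raise ValueError('AUROC needs at least one positive and one negative example')\n",
    "    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)\n",
    "    return wins / (len(pos) * len(neg))\n",
    "\n",
    "\n",
    "# toy example (hypothetical values)\n",
    "toy_labels = [0, 0, 1, 1]\n",
    "toy_scores = [0.1, 0.4, 0.35, 0.8]\n",
    "print(auroc(toy_labels, toy_scores))  # 0.75\n",
    "```\n",
    "\n",
    "Because AUROC depends only on the ranking of the scores, it is unaffected by the choice of a classification threshold."
   ]
  },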
  {
   "cell_type": "markdown",
   "id": "5193ee50-c4e4-43a0-89b9-d60a573b0ffc",
   "metadata": {},
   "source": [
    "## [Optional] Step 8: Plot genome-phenotype probabilities\n",
    "\n",
    "Plot predicted phenotype probability for test genomes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f61aecc9-4e65-4623-b6bd-7acdd496785e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the pandas DataFrame with data to plot\n",
    "plot_df = pd.DataFrame(\n",
    "    {\n",
    "        'probability': test_output.predictions.squeeze().tolist(),\n",
    "        'label': test_output.label_ids.tolist(),\n",
    "        'genome_name': dataset[\"test\"][\"genome_name\"],\n",
    "    }\n",
    ")\n",
    "\n",
    "# Create the KDE plot, with one density curve per label\n",
    "plt.figure(figsize=(8, 6))\n",
    "sns.kdeplot(\n",
    "    plot_df,\n",
    "    x='probability',\n",
    "    fill=True,\n",
    "    hue='label',\n",
    "    alpha=0.6,\n",
    ")\n",
    "\n",
    "# Add a title\n",
    "plt.title(\"Genome phenotype (Catalase) prediction\", fontsize=18)\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
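  {
   "cell_type": "markdown",
   "id": "e9a3c4d6-5f7b-4c8a-9d2e-3a4b5c6d7e8f",
   "metadata": {},
   "source": [
    "The plot above treats `test_output.predictions` as probabilities. If the values turn out to fall outside the `[0, 1]` range, they are raw logits rather than probabilities, and a sigmoid maps them to probabilities before thresholding. A minimal sketch under that assumption follows; `example_logits` and the `0.5` threshold are hypothetical choices, not values from this tutorial.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    '''Map a raw logit to a probability.'''\n",
    "    return 1.0 / (1.0 + math.exp(-x))\n",
    "\n",
    "\n",
    "example_logits = [-2.0, 0.0, 3.0]  # hypothetical model outputs\n",
    "threshold = 0.5  # assumed decision threshold\n",
    "\n",
    "probabilities = [sigmoid(x) for x in example_logits]\n",
    "predicted_labels = [int(p >= threshold) for p in probabilities]\n",
    "print(predicted_labels)  # [0, 1, 1]\n",
    "```\n",
    "\n",
    "A threshold other than `0.5` may be preferable when the two classes are imbalanced or when false positives and false negatives carry different costs."
   ]
  },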
  {
   "cell_type": "markdown",
   "id": "acc6fb02-0b52-45b6-a078-93a4d6b4075f",
   "metadata": {},
   "source": [
    "----------------------\n",
    "#### Voilà, you made it 👏!\n",
    "\n",
    "This example shows how to finetune Bacformer for a genome-level task. The same approach can be applied to any other phenotype for which genomes and labels are available.\n",
    "\n",
    "If you run into any issues or have questions, please raise an issue on GitHub: https://github.com/macwiatrak/Bacformer/issues.\n",
    "\n",
    "We also provide `139` diverse phenotypic trait labels distributed across almost 25k genomes. To use them, please see `phenotypic_traits_prediction_tutorial.ipynb`, which outlines how\n",
    "to train a linear regression model on top of precomputed genome embeddings.\n",
    "\n",
    "## References\n",
    "\n",
    "[1] Weimann, Aaron, et al. \"From genomes to phenotypes: Traitar, the microbial trait analyzer.\" mSystems (2016)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
